{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "267281b6",
   "metadata": {},
   "source": [
    "| [04_text_classification/03_Bert文本分类.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/04_text_classification/03_Bert文本分类.ipynb)  | 使用Bert模型finetune分类任务  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/04_text_classification/03_Bert文本分类.ipynb) |\n",
    "\n",
    "# Bert文本分类\n",
    "\n",
    "基于transformers库，使用预训练模型对文本分类。预训练模型在专业领域数据上使用时需要finetune效果更好。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3e536b4",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/prateekjoshi565/Fine-Tuning-BERT/blob/master/Fine_Tuning_BERT_for_Spam_Classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c637738c",
   "metadata": {},
   "source": [
    "# Install Transformers Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "08a820ec",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: transformers in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (4.10.3)\n",
      "Requirement already satisfied: packaging in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (20.9)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (2021.4.4)\n",
      "Requirement already satisfied: tqdm>=4.27 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (4.59.0)\n",
      "Requirement already satisfied: sacremoses in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (0.0.45)\n",
      "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (0.10.3)\n",
      "Requirement already satisfied: filelock in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (3.0.12)\n",
      "Requirement already satisfied: huggingface-hub>=0.0.12 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (0.0.12)\n",
      "Requirement already satisfied: requests in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (2.25.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (5.4.1)\n",
      "Requirement already satisfied: numpy>=1.17 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (1.20.1)\n",
      "Requirement already satisfied: typing-extensions in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from huggingface-hub>=0.0.12->transformers) (3.7.4.3)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from packaging->transformers) (2.4.7)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers) (1.26.4)\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers) (4.0.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers) (2021.5.30)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers) (2.10)\n",
      "Requirement already satisfied: joblib in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sacremoses->transformers) (1.0.1)\n",
      "Requirement already satisfied: click in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sacremoses->transformers) (7.1.2)\n",
      "Requirement already satisfied: six in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sacremoses->transformers) (1.15.0)\n"
     ]
    }
   ],
   "source": [
    "# %pip installs into the environment of the running kernel; a bare !pip can target a different Python\n",
    "%pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6d00598f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import classification_report\n",
    "import transformers\n",
    "from transformers import AutoModel, BertTokenizerFast\n",
    "\n",
    "# run on the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# fix the random seeds so that weight initialisation, dropout and batch shuffling\n",
    "# give the same results after every Restart & Run All\n",
    "SEED = 2018\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "822c1530",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f3d6490d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>Go until jurong point, crazy.. Available only ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>Ok lar... Joking wif u oni...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>Free entry in 2 a wkly comp to win FA Cup fina...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>U dun say so early hor... U c already then say...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>Nah I don't think he goes to usf, he lives aro...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label                                               text\n",
       "0      0  Go until jurong point, crazy.. Available only ...\n",
       "1      0                      Ok lar... Joking wif u oni...\n",
       "2      1  Free entry in 2 a wkly comp to win FA Cup fina...\n",
       "3      0  U dun say so early hor... U c already then say...\n",
       "4      0  Nah I don't think he goes to usf, he lives aro..."
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load the labelled messages (columns: label, text) and preview the first rows\n",
    "df = pd.read_csv(\"data/spamdata.csv\")\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b126df6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5572, 2)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (number of rows, number of columns)\n",
    "(len(df.index), len(df.columns))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "72e71187",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    0.865937\n",
       "1    0.134063\n",
       "Name: label, dtype: float64"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# share of each label in the data set (class balance check)\n",
    "label_share = df['label'].value_counts(normalize=True)\n",
    "label_share"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f05e698e",
   "metadata": {},
   "source": [
    "# Split the dataset into train, validation and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0d1d2754",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hold out 30% of the messages, keeping the label ratio identical in both parts\n",
    "train_text, temp_text, train_labels, temp_labels = train_test_split(\n",
    "    df['text'],\n",
    "    df['label'],\n",
    "    test_size=0.3,\n",
    "    stratify=df['label'],\n",
    "    random_state=2018,\n",
    ")\n",
    "\n",
    "# divide the held-out part evenly into a validation set and a test set\n",
    "val_text, test_text, val_labels, test_labels = train_test_split(\n",
    "    temp_text,\n",
    "    temp_labels,\n",
    "    test_size=0.5,\n",
    "    stratify=temp_labels,\n",
    "    random_state=2018,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d63e5fd",
   "metadata": {},
   "source": [
    "# Import BERT Model and BERT Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2c4b2310",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "# name of the pretrained checkpoint shared by the encoder and its tokenizer\n",
    "pretrained_name = 'bert-base-uncased'\n",
    "\n",
    "# pretrained BERT-base encoder\n",
    "bert = AutoModel.from_pretrained(pretrained_name)\n",
    "\n",
    "# matching BERT tokenizer\n",
    "tokenizer = BertTokenizerFast.from_pretrained(pretrained_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d64cf097",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# two example sentences to show what the tokenizer produces\n",
    "text = [\n",
    "    \"this is a bert model tutorial\",\n",
    "    \"we will fine-tune a bert model\",\n",
    "]\n",
    "\n",
    "# encode both sentences, padding the shorter one to the length of the longer one\n",
    "sent_id = tokenizer.batch_encode_plus(\n",
    "    text,\n",
    "    padding=True,\n",
    "    return_token_type_ids=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d2fc0ffb",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [[101, 2023, 2003, 1037, 14324, 2944, 14924, 4818, 102, 0], [101, 2057, 2097, 2986, 1011, 8694, 1037, 14324, 2944, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}\n"
     ]
    }
   ],
   "source": [
    "# show the encoded batch: token ids plus the attention mask (0 marks padding)\n",
    "print(sent_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "156fcc88",
   "metadata": {},
   "source": [
    "# Tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0b56ecb3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUuElEQVR4nO3dfbBcd33f8fendnCAC7Zckzuq5EYio9DxQ5ugOy4thbkaM7UBB7kP7ohxErl1R5PUpNCSKXKZKflHM0ozZEqGQEaNGZSaclEMjFVct7iuVaYzGBcZgywbxwKrRrYjNzwYRBgncr79Y49gUe692gdp713/3q+ZO7v729/Z89nj9Wd3z55dpaqQJLXlr6x0AEnS5Fn+ktQgy1+SGmT5S1KDLH9JatD5Kx3gTC655JLasGHDUMt8//vf5+Uvf/m5CXSOmPncm7a8YOZJmbbMg+Q9ePDgn1TVq5acUFWr+m/z5s01rPvuu2/oZVaamc+9actbZeZJmbbMg+QFvljLdKu7fSSpQZa/JDXI8pekBln+ktQgy1+SGmT5S1KDLH9JapDlL0kNsvwlqUGr/ucdJmHDzrsGmnd091vPcRJJmgxf+UtSg85Y/kk+kuTZJA/3jf1Wkq8m+UqSTye5qO+6W5McSfJYkmv6xjcnOdRd9ztJctbvjSRpIIO88v8ocO1pY/cAV1TV3wT+CLgVIMllwDbg8m6ZDyU5r1vmw8AOYFP3d/ptSpIm5IzlX1WfA7512thnq+pkd/F+YH13fiuwUFXPV9UTwBHgqiRrgVdW1ee7X5v7A+D6s3QfJElDSq+LzzAp2QB8pqquWOS6/wJ8oqpuT/JB4P6qur277jbgbuAosLuq3tSNvwF4T1Vdt8T6dtB7l8Ds7OzmhYWFoe7UiRMnmJmZGXj+oaeeG2jelesuHCrHMIbNvBpMW+ZpywtmnpRpyzxI3i1bthysqrmlrh/raJ8k7wVOAh87NbTItFpmfFFVtQfYAzA3N1fz8/ND5Tpw4ADDLHPToEf73DhcjmEMm3k1mLbM05YXzDwp05b5bOQdufyTbAeuA66uH719OAZc2jdtPfB0N75+kXFJ0goY6VDPJNcC7wHeVlV/2nfVfmBbkguSbKT3we4DVfUM8L0kr+uO8vll4M4xs0uSRnTGV/5JPg7MA5ckOQa8j97RPRcA93RHbN5fVb9SVYeT7AMeobc76JaqeqG7qV+ld+TQS+l9DnD32b0rkqRBnbH8q+rtiwzftsz8XcCuRca/CPylD4wlSZPnN3wlqUGWvyQ1yPKXpAZZ/pLUIMtfkhpk+UtSgyx/SWqQ5S9JDbL8JalBlr8kNcjyl6QGWf6S1CDLX5IaZPlLUoMsf0lqkOUvSQ2y/CWpQZa/JDXI8pekBln+ktQgy1+SGmT5S1KDLH9JapDlL0kNsvwlqUFnLP8kH0nybJKH+8YuTnJPkse70zV9192a5EiSx5Jc0ze+Ocmh7rrfSZKzf3ckSYMY5JX/R4FrTxvbCdxbVZuAe7vLJLkM2AZc3i3zoSTndct8GNgBbOr+Tr9NSdKEnLH8q+pzwLdOG94K7O3O7wWu7xtfqKrnq+oJ4AhwVZK1wCur6vNVVcAf9C0jSZqw9Lr4DJOSDcBnquqK7vJ3quqivuu/XVVrknwQuL+qbu/GbwPuBo4Cu6vqTd34G4D3VNV1S6xvB713CczOzm5eWFgY6k6dOHGCmZmZgecfeuq5geZdue7CoXIMY9jMq8G0ZZ62vGDmSZm2zIPk3bJly8Gqmlvq+vPPcqbF9uPXMuOLqqo9wB6Aubm5mp+fHyrEgQMHGGaZm3beNdC8ozcOl2MYw2ZeDaYt87TlBTNPyrRlPht5Rz3a53i3K4fu9Nlu/Bhwad+89cDT3fj6RcYlSStg1PLfD2zvzm8H7uwb35bkgiQb6X2w+0BVPQN8L8nruqN8frlvGUnShJ1xt0+SjwPzwCVJjgHvA3YD+5Lc
DDwJ3ABQVYeT7AMeAU4Ct1TVC91N/Sq9I4deSu9zgLvP6j2RJA3sjOVfVW9f4qqrl5i/C9i1yPgXgSuGSidJOif8hq8kNcjyl6QGWf6S1CDLX5IaZPlLUoMsf0lqkOUvSQ2y/CWpQZa/JDXI8pekBln+ktQgy1+SGmT5S1KDLH9JapDlL0kNsvwlqUGWvyQ1yPKXpAZZ/pLUIMtfkhpk+UtSgyx/SWqQ5S9JDbL8JalBY5V/kn+V5HCSh5N8PMlPJrk4yT1JHu9O1/TNvzXJkSSPJblm/PiSpFGMXP5J1gH/EpirqiuA84BtwE7g3qraBNzbXSbJZd31lwPXAh9Kct548SVJoxh3t8/5wEuTnA+8DHga2Ars7a7fC1zfnd8KLFTV81X1BHAEuGrM9UuSRpCqGn3h5J3ALuAHwGer6sYk36mqi/rmfLuq1iT5IHB/Vd3ejd8G3F1VdyxyuzuAHQCzs7ObFxYWhsp14sQJZmZmBp5/6KnnBpp35boLh8oxjGEzrwbTlnna8oKZJ2XaMg+Sd8uWLQeram6p688fdeXdvvytwEbgO8AfJvnF5RZZZGzRZ56q2gPsAZibm6v5+fmhsh04cIBhlrlp510DzTt643A5hjFs5tVg2jJPW14w86RMW+azkXec3T5vAp6oqv9XVX8OfAr4u8DxJGsButNnu/nHgEv7ll9PbzeRJGnCxin/J4HXJXlZkgBXA48C+4Ht3ZztwJ3d+f3AtiQXJNkIbAIeGGP9kqQRjbzbp6q+kOQO4EHgJPAlertqZoB9SW6m9wRxQzf/cJJ9wCPd/Fuq6oUx80uSRjBy+QNU1fuA9502/Dy9dwGLzd9F7wPiidgw4L58SWqN3/CVpAZZ/pLUIMtfkhpk+UtSgyx/SWqQ5S9JDbL8JalBlr8kNcjyl6QGWf6S1CDLX5IaZPlLUoMsf0lqkOUvSQ2y/CWpQZa/JDXI8pekBln+ktQgy1+SGmT5S1KDLH9JapDlL0kNsvwlqUGWvyQ1yPKXpAaNVf5JLkpyR5KvJnk0yd9JcnGSe5I83p2u6Zt/a5IjSR5Lcs348SVJoxj3lf8HgP9WVX8D+FvAo8BO4N6q2gTc210myWXANuBy4FrgQ0nOG3P9kqQRjFz+SV4JvBG4DaCq/qyqvgNsBfZ20/YC13fntwILVfV8VT0BHAGuGnX9kqTRpapGWzD5OWAP8Ai9V/0HgXcCT1XVRX3zvl1Va5J8ELi/qm7vxm8D7q6qOxa57R3ADoDZ2dnNCwsLQ2U7ceIEMzMzHHrquVHu2pKuXHfhWb29fqcyT5NpyzxtecHMkzJtmQfJu2XLloNVNbfU9eePsf7zgdcCv1ZVX0jyAbpdPEvIImOLPvNU1R56TyzMzc3V/Pz8UMEOHDjA/Pw8N+28a6jlzuTojcPlGMapzNNk2jJPW14w86RMW+azkXecff7HgGNV9YXu8h30ngyOJ1kL0J0+2zf/0r7l1wNPj7F+SdKIRi7/qvpj4BtJXtMNXU1vF9B+YHs3th24szu/H9iW5IIkG4FNwAOjrl+SNLpxdvsA/BrwsSQvAb4O/FN6Tyj7ktwMPAncAFBVh5Pso/cEcRK4papeGHP9kqQRjFX+VfUQsNgHClcvMX8XsGucdUqSxuc3fCWpQZa/JDXI8pekBo37ga/GtKH7LsK7rzy57PcSju5+66QiSWqAr/wlqUGWvyQ1yPKXpAZZ/pLUID/wHcKGAX8ozg9nJa12vvKXpAZZ/pLUIMtfkhpk+UtSgyx/SWqQR/ucAx4VJGm185W/JDXI8pekBln+ktQgy1+SGmT5S1KDLH9JapDlL0kNsvwlqUGWvyQ1aOzyT3Jeki8l+Ux3+eIk9yR5vDtd0zf31iRHkjyW5Jpx1y1JGs3ZeOX/TuDRvss7gXurahNwb3eZJJcB24DLgWuBDyU57yysX5I0pLHKP8l64K3A7/cNbwX2duf3Atf3jS9U1fNV9QRwBLhqnPVLkkYz7iv//wD8
G+Av+sZmq+oZgO70p7rxdcA3+uYd68YkSROWqhptweQ64C1V9S+SzAO/XlXXJflOVV3UN+/bVbUmye8Cn6+q27vx24D/WlWfXOS2dwA7AGZnZzcvLCwMle3EiRPMzMxw6KnnRrpvk3Llugt/mHH2pXD8B8vPXW1ObedpMW15wcyTMm2ZB8m7ZcuWg1U1t9T14/yk8+uBtyV5C/CTwCuT3A4cT7K2qp5JshZ4tpt/DLi0b/n1wNOL3XBV7QH2AMzNzdX8/PxQwQ4cOMD8/Dw3DfjTyivl6I0/yvjuK0/y/kNL/+c4euP8hFIN7tR2nhbTlhfMPCnTlvls5B15t09V3VpV66tqA70Pcv9nVf0isB/Y3k3bDtzZnd8PbEtyQZKNwCbggZGTS5JGdi7+MZfdwL4kNwNPAjcAVNXhJPuAR4CTwC1V9cI5WL8k6QzOSvlX1QHgQHf+m8DVS8zbBew6G+uUJI3Ob/hKUoMsf0lqkOUvSQ2y/CWpQZa/JDXI8pekBln+ktQgy1+SGmT5S1KDLH9JapDlL0kNsvwlqUGWvyQ1yPKXpAZZ/pLUIMtfkhpk+UtSgyx/SWqQ5S9JDbL8JalBlr8kNcjyl6QGWf6S1CDLX5IaZPlLUoNGLv8klya5L8mjSQ4neWc3fnGSe5I83p2u6Vvm1iRHkjyW5JqzcQckScM7f4xlTwLvrqoHk7wCOJjkHuAm4N6q2p1kJ7ATeE+Sy4BtwOXAXwP+R5KfraoXxrsLbdiw866B5h3d/dZznETSi8HIr/yr6pmqerA7/z3gUWAdsBXY203bC1zfnd8KLFTV81X1BHAEuGrU9UuSRpeqGv9Gkg3A54ArgCer6qK+675dVWuSfBC4v6pu78ZvA+6uqjsWub0dwA6A2dnZzQsLC0PlOXHiBDMzMxx66rkR79FkXLnuwh9mnH0pHP/B2bnNSTm1nafFtOUFM0/KtGUeJO+WLVsOVtXcUtePs9sHgCQzwCeBd1XVd5MsOXWRsUWfeapqD7AHYG5urubn54fKdODAAebn57lpwF0lK+XojT/K+O4rT/L+Q2P/5+DojfNj38agTm3naTFtecHMkzJtmc9G3rGO9knyE/SK/2NV9alu+HiStd31a4Fnu/FjwKV9i68Hnh5n/ZKk0YxztE+A24BHq+q3+67aD2zvzm8H7uwb35bkgiQbgU3AA6OuX5I0unH2M7we+CXgUJKHurF/C+wG9iW5GXgSuAGgqg4n2Qc8Qu9IoVs80keSVsbI5V9V/5vF9+MDXL3EMruAXaOuU5J0dvgNX0lqkOUvSQ2y/CWpQZa/JDXI8pekBln+ktQgy1+SGmT5S1KDLH9JapDlL0kNsvwlqUGWvyQ1yPKXpAZZ/pLUIMtfkhpk+UtSgyx/SWrQOP+Mo1ahDTvvGmje0d1vPcdJJK1mvvKXpAZZ/pLUIMtfkhpk+UtSgyx/SWqQ5S9JDfJQz0Z5SKjUtomXf5JrgQ8A5wG/X1W7J51BK2fQJx3wiUc6lya62yfJecDvAm8GLgPenuSySWaQJE3+lf9VwJGq+jpAkgVgK/DIhHNoQMu9Un/3lSe5aYhX8ufKoO8mBs076DuOYd7FDMJ3OpqkVNXkVpb8Y+Daqvrn3eVfAv52Vb3jtHk7gB3dxdcAjw25qkuAPxkz7qSZ+dybtrxg5kmZtsyD5P3pqnrVUldO+pV/Fhn7S88+VbUH2DPySpIvVtXcqMuvBDOfe9OWF8w8KdOW+WzknfShnseAS/surweennAGSWrepMv//wCbkmxM8hJgG7B/whkkqXkT3e1TVSeTvAP47/QO9fxIVR0+B6saeZfRCjLzuTdtecHMkzJtmcfOO9EPfCVJq4M/7yBJDbL8JalBL7ryT3JtkseSHEmyc6XznC7JpUnuS/JoksNJ3tmN/0aSp5I81P29ZaWz9ktyNMmhLtsXu7GLk9yT5PHudM1K5zwlyWv6tuVDSb6b5F2rbTsn+UiSZ5M83De25HZNcmv3
2H4syTWrJO9vJflqkq8k+XSSi7rxDUl+0Letf2/SeZfJvOTjYKW38TKZP9GX92iSh7rx0bZzVb1o/uh9iPw14NXAS4AvA5etdK7TMq4FXtudfwXwR/R+6uI3gF9f6XzL5D4KXHLa2L8HdnbndwK/udI5l3lc/DHw06ttOwNvBF4LPHym7do9Tr4MXABs7B7r562CvH8fOL87/5t9eTf0z1tl23jRx8Fq2MZLZT7t+vcD/26c7fxie+X/w5+PqKo/A079fMSqUVXPVNWD3fnvAY8C61Y21ci2Anu783uB61cuyrKuBr5WVf93pYOcrqo+B3zrtOGltutWYKGqnq+qJ4Aj9B7zE7NY3qr6bFWd7C7eT+/7O6vGEtt4KSu+jWH5zEkC/BPg4+Os48VW/uuAb/RdPsYqLtYkG4CfB77QDb2je+v8kdW0C6VTwGeTHOx+fgNgtqqegd6TGvBTK5Zuedv48f9RVvN2hqW36zQ8vv8ZcHff5Y1JvpTkfyV5w0qFWsJij4Np2MZvAI5X1eN9Y0Nv5xdb+Q/08xGrQZIZ4JPAu6rqu8CHgZ8Bfg54ht7butXk9VX1Wnq/yHpLkjeudKBBdF8mfBvwh93Qat/Oy1nVj+8k7wVOAh/rhp4B/npV/Tzwr4H/nOSVK5XvNEs9Dlb1Nu68nR9/MTPSdn6xlf9U/HxEkp+gV/wfq6pPAVTV8ap6oar+AviPrMBbzeVU1dPd6bPAp+nlO55kLUB3+uzKJVzSm4EHq+o4rP7t3Flqu67ax3eS7cB1wI3V7Yjudp18szt/kN7+859duZQ/sszjYNVuY4Ak5wP/EPjEqbFRt/OLrfxX/c9HdPvrbgMerarf7htf2zftHwAPn77sSkny8iSvOHWe3gd8D9Pbttu7aduBO1cm4bJ+7FXSat7OfZbarvuBbUkuSLIR2AQ8sAL5fkx6/0DTe4C3VdWf9o2/Kr1/w4Mkr6aX9+srk/LHLfM4WJXbuM+bgK9W1bFTAyNv50l/ij2BT8nfQu8Imq8B713pPIvk+3v03kZ+BXio+3sL8J+AQ934fmDtSmfty/xqekdAfBk4fGq7An8VuBd4vDu9eKWznpb7ZcA3gQv7xlbVdqb3xPQM8Of0XnXevNx2Bd7bPbYfA968SvIeobef/NTj+fe6uf+oe7x8GXgQ+IVVtI2XfBys9DZeKnM3/lHgV06bO9J29ucdJKlBL7bdPpKkAVj+ktQgy1+SGmT5S1KDLH9JapDlL0kNsvwlqUH/H/2NyN+cyganAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# number of whitespace-separated words in every training message\n",
    "seq_len = [len(message.split()) for message in train_text]\n",
    "\n",
    "# distribution of the message lengths\n",
    "pd.Series(seq_len).hist(bins=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "94d1d10e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# fixed length (in tokens) that the tokenizer pads or truncates every message to\n",
    "max_seq_len = 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bd2098fa",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/xuming/opt/anaconda3/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:2198: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "def encode_texts(texts):\n",
    "    \"\"\"Tokenize a pandas Series of strings into fixed-length BERT inputs.\n",
    "\n",
    "    Every message is truncated or padded to `max_seq_len` tokens.\n",
    "\n",
    "    Args:\n",
    "        texts: pandas Series holding one message per entry.\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output with the 'input_ids' and 'attention_mask' entries\n",
    "        (token type ids are not requested).\n",
    "    \"\"\"\n",
    "    return tokenizer.batch_encode_plus(\n",
    "        texts.tolist(),\n",
    "        max_length=max_seq_len,\n",
    "        # padding='max_length' replaces the deprecated pad_to_max_length=True\n",
    "        padding='max_length',\n",
    "        truncation=True,\n",
    "        return_token_type_ids=False\n",
    "    )\n",
    "\n",
    "# tokenize and encode sequences in the training, validation and test sets\n",
    "tokens_train = encode_texts(train_text)\n",
    "tokens_val = encode_texts(val_text)\n",
    "tokens_test = encode_texts(test_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6acf47f7",
   "metadata": {},
   "source": [
    "# Convert Integer Sequences to Tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8a0e94a1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# training set: token ids, attention mask and labels as tensors\n",
    "train_seq, train_mask = (torch.tensor(tokens_train[key]) for key in ('input_ids', 'attention_mask'))\n",
    "train_y = torch.tensor(train_labels.tolist())\n",
    "\n",
    "# validation set\n",
    "val_seq, val_mask = (torch.tensor(tokens_val[key]) for key in ('input_ids', 'attention_mask'))\n",
    "val_y = torch.tensor(val_labels.tolist())\n",
    "\n",
    "# test set\n",
    "test_seq, test_mask = (torch.tensor(tokens_test[key]) for key in ('input_ids', 'attention_mask'))\n",
    "test_y = torch.tensor(test_labels.tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e000817d",
   "metadata": {},
   "source": [
    "# Create DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "eabf6277",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
    "\n",
    "# number of samples per batch\n",
    "batch_size = 32\n",
    "\n",
    "# training set: bundle the tensors and draw the batches in random order\n",
    "train_data = TensorDataset(train_seq, train_mask, train_y)\n",
    "train_sampler = RandomSampler(train_data)\n",
    "train_dataloader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)\n",
    "\n",
    "# validation set: bundle the tensors and iterate over them in their stored order\n",
    "val_data = TensorDataset(val_seq, val_mask, val_y)\n",
    "val_sampler = SequentialSampler(val_data)\n",
    "val_dataloader = DataLoader(val_data, batch_size=batch_size, sampler=val_sampler)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdf36387",
   "metadata": {},
   "source": [
    "# Freeze BERT Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "aba96f23",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# freeze the encoder: none of its parameters will receive gradients\n",
    "for weight in bert.parameters():\n",
    "    weight.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abec1aad",
   "metadata": {},
   "source": [
    "# Define Model Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "98d0f697",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class BERT_Arch(nn.Module):\n",
    "    \"\"\"Classification head on top of a BERT encoder.\n",
    "\n",
    "    The pooled encoder output is passed through\n",
    "    Linear -> ReLU -> Dropout -> Linear -> LogSoftmax.\n",
    "\n",
    "    Args:\n",
    "        bert: encoder module. It is called as\n",
    "            `bert(sent_id, attention_mask=mask, return_dict=False)` and must return a\n",
    "            pair whose second element is the pooled output fed to the head.\n",
    "        hidden_dim: width of the intermediate dense layer (default 512).\n",
    "        num_classes: number of output classes (default 2).\n",
    "        dropout: dropout probability applied after the ReLU (default 0.1).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, bert, hidden_dim=512, num_classes=2, dropout=0.1):\n",
    "        super().__init__()\n",
    "\n",
    "        self.bert = bert\n",
    "\n",
    "        # read the encoder width from its config instead of hard-coding 768, so that\n",
    "        # other model sizes work too; fall back to 768 when no config is exposed\n",
    "        bert_dim = getattr(getattr(bert, 'config', None), 'hidden_size', 768)\n",
    "\n",
    "        # dropout layer\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "\n",
    "        # relu activation function\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "        # dense layer 1\n",
    "        self.fc1 = nn.Linear(bert_dim, hidden_dim)\n",
    "\n",
    "        # dense layer 2 (output layer)\n",
    "        self.fc2 = nn.Linear(hidden_dim, num_classes)\n",
    "\n",
    "        # log-softmax over the class dimension (the output is log-probabilities)\n",
    "        self.softmax = nn.LogSoftmax(dim=1)\n",
    "\n",
    "    def forward(self, sent_id, mask):\n",
    "        \"\"\"Return the per-class log-probabilities for a batch.\n",
    "\n",
    "        Args:\n",
    "            sent_id: token ids passed to the encoder.\n",
    "            mask: attention mask passed to the encoder.\n",
    "\n",
    "        Returns:\n",
    "            Tensor with one row per sample and one column per class.\n",
    "        \"\"\"\n",
    "        # only the second (pooled) output of the encoder is used\n",
    "        _, cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False)\n",
    "\n",
    "        x = self.fc1(cls_hs)\n",
    "        x = self.relu(x)\n",
    "        x = self.dropout(x)\n",
    "\n",
    "        # output layer\n",
    "        x = self.fc2(x)\n",
    "\n",
    "        # convert the scores to log-probabilities\n",
    "        return self.softmax(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6786138a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# wrap the pre-trained BERT in the classifier architecture and move it to the selected device\n",
    "model = BERT_Arch(bert).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f31ed040",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# optimizer: the AdamW class shipped with transformers is deprecated and its\n",
    "# deprecation warning points to the PyTorch implementation, which is used here\n",
    "from torch.optim import AdamW\n",
    "\n",
    "# define the optimizer; eps and weight_decay are set explicitly to the defaults of the\n",
    "# former transformers implementation so that training behaves as before\n",
    "optimizer = AdamW(model.parameters(), lr=1e-3, eps=1e-6, weight_decay=0.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50257fc5",
   "metadata": {},
   "source": [
    "# Find Class Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ceeb2377",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.57743559 3.72848948]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/xuming/opt/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass classes=[0 1], y=1188    0\n",
      "5403    0\n",
      "2924    0\n",
      "880     1\n",
      "719     0\n",
      "       ..\n",
      "5121    0\n",
      "1006    1\n",
      "4374    1\n",
      "903     0\n",
      "5057    0\n",
      "Name: label, Length: 3900, dtype: int64 as keyword args. From version 1.0 (renaming of 0.25) passing these as positional arguments will result in an error\n",
      "  warnings.warn(f\"Pass {args_msg} as keyword args. From version \"\n"
     ]
    }
   ],
   "source": [
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "\n",
    "# compute the class weights of the training labels with the 'balanced' scheme\n",
    "# (classes and y are passed by keyword: scikit-learn 1.0 and later reject positional use)\n",
    "class_wts = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)\n",
    "\n",
    "print(class_wts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "814c356b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# class weights as a float tensor on the selected device\n",
    "weights = torch.tensor(class_wts, dtype=torch.float).to(device)\n",
    "\n",
    "# weighted negative log-likelihood loss\n",
    "cross_entropy = nn.NLLLoss(weight=weights)\n",
    "\n",
    "# number of training epochs\n",
    "epochs = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c462778d",
   "metadata": {},
   "source": [
    "# Fine-Tune BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d758b8dc",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one training epoch over `train_dataloader`.\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, total_preds): the mean batch loss of the epoch and the model\n",
    "        outputs of all batches concatenated into one numpy array.\n",
    "    \"\"\"\n",
    "    # enable training mode\n",
    "    model.train()\n",
    "\n",
    "    running_loss = 0\n",
    "    batch_outputs = []\n",
    "    n_batches = len(train_dataloader)\n",
    "\n",
    "    for step, batch in enumerate(train_dataloader):\n",
    "\n",
    "        # progress update after every 50 batches\n",
    "        if step != 0 and step % 50 == 0:\n",
    "            print('  Batch {:>5,}  of  {:>5,}.'.format(step, n_batches))\n",
    "\n",
    "        # move the batch to the selected device and unpack it\n",
    "        sent_id, mask, labels = [t.to(device) for t in batch]\n",
    "\n",
    "        # drop the gradients left over from the previous step\n",
    "        model.zero_grad()\n",
    "\n",
    "        # forward pass and loss against the true labels\n",
    "        preds = model(sent_id, mask)\n",
    "        loss = cross_entropy(preds, labels)\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        # backward pass; clip the gradient norm to 1.0 before the parameter update\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        optimizer.step()\n",
    "\n",
    "        # keep the outputs of this batch as a numpy array on the CPU\n",
    "        batch_outputs.append(preds.detach().cpu().numpy())\n",
    "\n",
    "    # mean loss over the batches of the epoch\n",
    "    avg_loss = running_loss / n_batches\n",
    "\n",
    "    # stack the per-batch outputs into (number of samples, number of classes)\n",
    "    total_preds = np.concatenate(batch_outputs, axis=0)\n",
    "\n",
    "    return avg_loss, total_preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "45ba906e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# function for evaluating the model\n",
    "def evaluate():\n",
    "    \"\"\"Evaluate the model on `val_dataloader` without updating any weights.\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, total_preds): the mean batch loss over the validation set and the\n",
    "        model outputs of all batches concatenated into one numpy array.\n",
    "    \"\"\"\n",
    "    print(\"\\nEvaluating...\")\n",
    "\n",
    "    # deactivate dropout layers\n",
    "    model.eval()\n",
    "\n",
    "    total_loss = 0\n",
    "\n",
    "    # empty list to save the model predictions\n",
    "    total_preds = []\n",
    "\n",
    "    # iterate over batches\n",
    "    for step, batch in enumerate(val_dataloader):\n",
    "\n",
    "        # progress update every 50 batches\n",
    "        # (the elapsed-time computation that used to sit here called `format_time`, `time`\n",
    "        # and `t0`, which are not defined anywhere above this cell, and its result was unused)\n",
    "        if step % 50 == 0 and not step == 0:\n",
    "            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(val_dataloader)))\n",
    "\n",
    "        # push the batch to the selected device\n",
    "        batch = [t.to(device) for t in batch]\n",
    "\n",
    "        sent_id, mask, labels = batch\n",
    "\n",
    "        # deactivate autograd\n",
    "        with torch.no_grad():\n",
    "\n",
    "            # model predictions\n",
    "            preds = model(sent_id, mask)\n",
    "\n",
    "            # compute the validation loss between actual and predicted values\n",
    "            loss = cross_entropy(preds, labels)\n",
    "\n",
    "            total_loss = total_loss + loss.item()\n",
    "\n",
    "            preds = preds.detach().cpu().numpy()\n",
    "\n",
    "            total_preds.append(preds)\n",
    "\n",
    "    # compute the validation loss of the epoch\n",
    "    avg_loss = total_loss / len(val_dataloader)\n",
    "\n",
    "    # reshape the predictions in form of (number of samples, no. of classes)\n",
    "    total_preds = np.concatenate(total_preds, axis=0)\n",
    "\n",
    "    return avg_loss, total_preds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cff6337b",
   "metadata": {},
   "source": [
    "# Start Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c3ce0c2",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# lowest validation loss seen so far; +inf guarantees the first epoch is saved\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "# per-epoch loss history (training, validation)\n",
    "train_losses, valid_losses = [], []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    print('\\n Epoch {:} / {:}'.format(epoch + 1, epochs))\n",
    "\n",
    "    # one pass over the training data, then score on the validation data;\n",
    "    # the predictions returned by both calls are not needed here\n",
    "    train_loss, _ = train()\n",
    "    valid_loss, _ = evaluate()\n",
    "\n",
    "    # checkpoint the weights whenever the validation loss improves\n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), 'saved_weights.pt')\n",
    "\n",
    "    # record this epoch's losses\n",
    "    train_losses.append(train_loss)\n",
    "    valid_losses.append(valid_loss)\n",
    "\n",
    "    print(f'\\nTraining Loss: {train_loss:.3f}')\n",
    "    print(f'Validation Loss: {valid_loss:.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de5fd546",
   "metadata": {},
   "source": [
    "# Load Saved Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4323d80",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#load weights of best model\n",
    "path = 'saved_weights.pt'\n",
    "# map_location remaps the stored tensors onto the current `device`, so a checkpoint\n",
    "# written on a GPU still loads on a CPU-only machine\n",
    "model.load_state_dict(torch.load(path, map_location=device))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab275362",
   "metadata": {},
   "source": [
    "# Get Predictions for Test Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ad1264",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# get predictions for test data\n",
    "# switch to eval mode explicitly so dropout is off even when this cell is run\n",
    "# straight after loading the weights, without a preceding evaluate() call\n",
    "model.eval()\n",
    "\n",
    "# inference only: no gradients needed\n",
    "with torch.no_grad():\n",
    "  preds = model(test_seq.to(device), test_mask.to(device))\n",
    "  preds = preds.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbccc2e5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# model's performance\n",
    "# store the arg-max class ids under their own name instead of overwriting `preds`,\n",
    "# so this cell can be re-run without np.argmax failing on an already-reduced array\n",
    "pred_labels = np.argmax(preds, axis=1)\n",
    "print(classification_report(test_y, pred_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc6fd88b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# confusion matrix\n",
    "pd.crosstab(test_y, pred_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83f537df",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Bert 蕴含推理任务\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54185ba5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from transformers import pipeline\n",
    "\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "model_name = \"hfl/chinese-macbert-base\"\n",
    "\n",
    "# device=-1 runs the pipeline on CPU; pass a CUDA device index (e.g. 0) to use a GPU\n",
    "nlp = pipeline(\n",
    "    \"sentiment-analysis\",\n",
    "    model=model_name,\n",
    "    tokenizer=model_name,\n",
    "    device=-1,\n",
    ")\n",
    "\n",
    "# classify two short example sentences and print the top label with its score\n",
    "for text in (\"我爱你\", \"我恨你\"):\n",
    "    result = nlp(text)[0]\n",
    "    print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6160dac1",
   "metadata": {},
   "source": [
    "模型默认保存在：`~/.cache/huggingface/transformers`\n",
    "\n",
    "`hfl/chinese-macbert-base`模型介绍：\n",
    "https://huggingface.co/hfl/chinese-macbert-base?text=%E5%B7%B4%E9%BB%8E%E6%98%AF%5BMASK%5D%E5%9B%BD%E7%9A%84%E9%A6%96%E9%83%BD%E3%80%82\n",
    "\n",
    "注意：`hfl/chinese-macbert-base`只是预训练模型，没有在分类任务上finetune过，加载后分类层是随机初始化的，所以上面的预测结果仅用于演示调用方式。换用别人已经finetune过的分类模型，效果会更好。\n",
    "\n",
    "\n",
    "不通过pipeline，可以自己写预测逻辑："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6aa54972",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "import torch\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "print(\"token ok\")\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
    "print(\"model ok\")\n",
    "\n",
    "\n",
    "def pair_class_probs(text_a, text_b):\n",
    "    \"\"\"Return the softmax class probabilities for one sentence pair.\n",
    "\n",
    "    The two texts are tokenized together as a single pair input, passed through\n",
    "    `model`, and the logits of that single example are converted into a plain\n",
    "    list of floats (one entry per class).\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(text_a, text_b, return_tensors=\"pt\")\n",
    "    # inference only: skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        logits = model(**encoded).logits\n",
    "    return torch.softmax(logits, dim=1).tolist()[0]\n",
    "\n",
    "\n",
    "classes = [\"not paraphrase\", \"is paraphrase\"]\n",
    "sequence_0 = \"中国首都是北京\"\n",
    "sequence_1 = \"苹果有益于你的身体健康\"\n",
    "sequence_2 = \"北京是在北回归线附近的城市\"\n",
    "\n",
    "# the first pair is the intended paraphrase example, the second the intended\n",
    "# non-paraphrase example\n",
    "for other in (sequence_2, sequence_1):\n",
    "    probs = pair_class_probs(sequence_0, other)\n",
    "    print(\"-\"*42)\n",
    "    print(sequence_0 + ' + ' + other + ', paraphrase proba:')\n",
    "    for class_name, prob in zip(classes, probs):\n",
    "        print(f\"{class_name}: {int(round(prob * 100))}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45d4dca6",
   "metadata": {},
   "source": [
    "本节完。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75945ed2",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}